"""
Metrics module for cognitive map evaluation.

This module provides functions to:
1. Calculate similarity between cognitive maps
2. Calculate directional and facing similarity
3. Check for isomorphism between graphs
"""

import numpy as np
from typing import Dict, List, Tuple, Optional, Any, Union, Set

from .graph_operations import (
    extract_objects_with_extended_info,
    build_comprehensive_relation_matrix,
    get_rotation_matrices,
    apply_rotation_to_map
)

from .extraction import (
    validate_cogmap_format,
    is_complex_format,
    trucate_object_position
)

def calculate_cogmap_similarity(generated_map: Dict, grounded_map: Dict) -> Dict:
    """
    Compare a generated cognitive map against the ground-truth map.

    Thin compatibility wrapper around calculate_extended_cogmap_similarity that
    re-exposes its scores under the legacy metric names: "isomorphic" mirrors
    "rotation_invariant_isomorphic", and "position_similarity",
    "directional_similarity" and "relative_position_accuracy" all carry the
    extended directional score.

    Args:
        generated_map: Generated cognitive map
        grounded_map: Ground truth cognitive map

    Returns:
        Dictionary of similarity metrics; the all-default result from
        _empty_similarity_result() when either map is empty/falsy.
    """
    if not (generated_map and grounded_map):
        return _empty_similarity_result()

    extended = calculate_extended_cogmap_similarity(generated_map, grounded_map)

    # Read each aliased score once; several legacy keys share the same value.
    isomorphic = extended["rotation_invariant_isomorphic"]
    directional = extended["directional_similarity"]
    facing = extended["facing_similarity"]

    return {
        "isomorphic": isomorphic,
        "rotation_invariant_isomorphic": isomorphic,
        "position_similarity": directional,
        "facing_similarity": facing,
        "directional_similarity": directional,
        "relative_position_accuracy": directional,
        "overall_similarity": extended["overall_similarity"],
        "valid_graph": extended["valid_graph"],
        "parsable_json": extended.get("parsable_json", True),
        "valid_format": extended.get("valid_format", False),
        "coverage": extended["coverage"],
        "best_rotation": extended["best_rotation"],
    }

def _empty_similarity_result() -> Dict:
    """
    Returns an empty similarity result with default values.
    
    Returns:
        Dictionary with default metrics
    """
    return {
        "isomorphic": False,
        "rotation_invariant_isomorphic": False,
        "position_similarity": 0.0,
        "facing_similarity": 0.0,
        "directional_similarity": 0.0,
        "relative_position_accuracy": 0.0,
        "overall_similarity": 0.0,
        "valid_graph": False,
        "parsable_json": False,
        "valid_format": False,
        "coverage": 0.0,
        "best_rotation": None
    }

def calculate_extended_cogmap_similarity(generated_map: Dict, grounded_map: Dict) -> Dict:
    """
    Calculate similarity between generated and grounded cognitive maps.

    Supports inner/outer relationships and 3D rotation invariance, and handles
    both the simple format (only objects) and the complex format (objects and
    views). Every rotation from ``get_rotation_matrices()`` is applied to the
    generated map, and the rotation with the highest overall similarity
    (``0.7 * directional + 0.3 * facing``) is reported.

    Args:
        generated_map: Generated cognitive map
        grounded_map: Ground truth cognitive map

    Returns:
        Dictionary of extended similarity metrics with the same keys as
        ``_empty_extended_similarity_result()``. "rotation_invariant_isomorphic"
        is True when at least one tested rotation makes the generated map
        reproduce every ground-truth relation. A rotation that raises while
        being scored is reported on stdout and skipped.
    """
    # Start from the all-failed result and fill it in as checks pass.
    result = _empty_extended_similarity_result()

    # Missing/empty or non-dict inputs are not parsable maps.
    if not generated_map or not grounded_map:
        result["parsable_json"] = False
        return result
    if not isinstance(generated_map, dict) or not isinstance(grounded_map, dict):
        result["parsable_json"] = False
        return result
    result["parsable_json"] = True

    # Simple-format generated maps have their object positions truncated first.
    if not is_complex_format(generated_map):
        generated_map = trucate_object_position(generated_map)

    # Both maps must pass format validation before they can be compared.
    gen_valid_format, _gen_errors = validate_cogmap_format(generated_map)
    ground_valid_format, _ground_errors = validate_cogmap_format(grounded_map)
    result["valid_format"] = gen_valid_format and ground_valid_format
    if not result["valid_format"]:
        return result

    # Extract objects and positions; an empty extraction means nothing to compare.
    gen_data = extract_objects_with_extended_info(generated_map)
    ground_data = extract_objects_with_extended_info(grounded_map)
    if not gen_data or not ground_data:
        return result
    result["valid_graph"] = True

    gen_objects_set = set(gen_data.keys())
    ground_objects_set = _select_ground_objects(generated_map, grounded_map, ground_data)

    # Coverage: fraction of evaluated ground-truth objects present in the generated map.
    common_objects = ground_objects_set & gen_objects_set
    result["coverage"] = (
        len(common_objects) / len(ground_objects_set) if ground_objects_set else 0.0
    )
    result["common_objects"] = list(common_objects)

    # Without shared objects the valid graph cannot be scored any further.
    if not common_objects:
        return result

    ground_relations = build_comprehensive_relation_matrix(ground_data, list(ground_objects_set))
    # Loop-invariant: the generated object list is the same for every rotation.
    gen_object_list = list(gen_objects_set)

    best_similarity = 0.0
    best_rotation = None
    best_directional_sim = 0.0
    best_facing_sim = 0.0
    any_rotation_isomorphic = False

    for rotation in get_rotation_matrices():
        try:
            rotated_gen_data = apply_rotation_to_map(gen_data, rotation)
            gen_relations = build_comprehensive_relation_matrix(rotated_gen_data, gen_object_list)
            # The generated map must contain every ground-truth relation.
            is_isomorphic = check_rotation_invariant_isomorphism(gen_relations, ground_relations)
            directional_sim = _directional_similarity(ground_relations, gen_relations, common_objects)
            facing_sim = _facing_similarity(ground_data, rotated_gen_data, common_objects)
        except Exception as e:
            # Best-effort: a rotation that cannot be scored is skipped, not fatal.
            label = rotation.get("name", rotation) if isinstance(rotation, dict) else rotation
            print(f"Error during rotation {label}: {e}")
            continue

        # Isomorphism is rotation-invariant if ANY rotation achieves it, even when
        # another rotation scores higher overall (e.g. via better facings).
        any_rotation_isomorphic = any_rotation_isomorphic or bool(is_isomorphic)

        overall_sim = 0.7 * directional_sim + 0.3 * facing_sim
        if overall_sim > best_similarity:
            best_similarity = overall_sim
            best_rotation = rotation
            best_directional_sim = directional_sim
            best_facing_sim = facing_sim

    result["rotation_invariant_isomorphic"] = any_rotation_isomorphic
    result["directional_similarity"] = best_directional_sim
    result["facing_similarity"] = best_facing_sim
    result["overall_similarity"] = best_similarity
    result["best_rotation"] = best_rotation

    return result


def _select_ground_objects(generated_map: Dict, grounded_map: Dict, ground_data: Dict) -> Set:
    """Return the ground-truth names to evaluate, restricted to those with extracted data."""
    available = set(ground_data.keys())
    gen_is_complex = isinstance(generated_map, dict) and "views" in generated_map
    if gen_is_complex or not isinstance(grounded_map, dict) or "objects" not in grounded_map:
        return available
    # A simple-format generated map has no views, so only the ground truth's
    # "objects" entries are evaluated. Intersecting with ``available`` keeps later
    # lookups into ``ground_data`` from failing on entries that were not extracted.
    named = {
        obj["name"]
        for obj in grounded_map.get("objects", [])
        if isinstance(obj, dict) and "name" in obj
    }
    return named & available


def _directional_similarity(ground_relations: Dict, gen_relations: Dict, common_objects: Set) -> float:
    """Fraction of defined ground-truth relations between common objects that gen_relations reproduces."""
    total = 0
    matching = 0
    for obj1 in common_objects:
        ground_row = ground_relations.get(obj1, {})
        gen_row = gen_relations.get(obj1, {})
        for obj2 in common_objects:
            if obj1 == obj2:
                continue
            ground_rel = ground_row.get(obj2)
            # Pairs without a ground-truth relation do not count toward the total.
            if ground_rel is None:
                continue
            total += 1
            if gen_row.get(obj2) == ground_rel:
                matching += 1
    return matching / total if total > 0 else 0.0


def _facing_similarity(ground_data: Dict, rotated_gen_data: Dict, common_objects: Set) -> float:
    """Fraction of common objects with a ground-truth facing whose rotated facing matches (1.0 if none)."""
    total = 0
    matching = 0
    for obj in common_objects:
        ground_facing = ground_data[obj].get("facing")
        # Objects without a ground-truth facing are not scored.
        if not ground_facing:
            continue
        total += 1
        if rotated_gen_data[obj].get("facing") == ground_facing:
            matching += 1
    return matching / total if total > 0 else 1.0

def _empty_extended_similarity_result() -> Dict:
    """
    Returns an empty extended similarity result with default values.
    
    Returns:
        Dictionary with default extended metrics
    """
    return {
        "rotation_invariant_isomorphic": False,
        "directional_similarity": 0.0,
        "facing_similarity": 0.0,
        "overall_similarity": 0.0,
        "valid_graph": False,
        "parsable_json": False,
        "valid_format": False,
        "coverage": 0.0,
        "best_rotation": None,
        "common_objects": []
    }

def check_rotation_invariant_isomorphism(gen_relations: Dict,
                                         ground_relations: Dict) -> bool:
    """
    Decide whether the generated relation matrix covers the ground-truth one.

    The generated map may contain extra objects and relations. It must contain
    every ground-truth object that appears as a row or as a related target, and
    must agree with every ground-truth relation value.

    Args:
        gen_relations: Generated map relation matrix (row -> {target: relation})
        ground_relations: Ground truth relation matrix (same layout)

    Returns:
        True if every ground-truth relation is reproduced, False otherwise
    """
    for source, expected_row in ground_relations.items():
        if source not in gen_relations:
            return False

        candidate_row = gen_relations.get(source, {})
        # A target is a mismatch if it has no row of its own in the generated
        # matrix, or if the generated relation value differs.
        has_mismatch = any(
            target not in gen_relations or candidate_row.get(target) != expected
            for target, expected in expected_row.items()
        )
        if has_mismatch:
            return False

    return True




# ==============================
# --------- TESTS ---------
# ==============================


def test_calculate_cogmap_similarity():
    """
    Smoke-test calculate_cogmap_similarity on hand-built maps.

    Nothing is asserted; each comparison's result dictionary is printed for
    manual inspection.
    """
    print("========== Testing calculate_cogmap_similarity ==========")

    def entry(name, position, facing=None):
        # Fresh dict/list per call so no two maps share mutable state.
        record = {"name": name, "position": list(position)}
        if facing is not None:
            record["facing"] = facing
        return record

    def standard_objects(first_name="object1"):
        return [
            entry(first_name, (1, 2), "up"),
            entry("object2", (3, 4), "down"),
            entry("object3", (5, 6)),
        ]

    def standard_views():
        return [entry("view1", (1, 2), "up"), entry("view2", (3, 4), "down")]

    # Complex format: objects plus views.
    complex_map = {"objects": standard_objects(), "views": standard_views()}
    # Complex format identical to complex_map except for the first object's name.
    renamed_map = {"objects": standard_objects("objectfhjgkhjgkhjg"), "views": standard_views()}
    # Simple format: objects only.
    simple_map = {"objects": standard_objects()}
    # complex_map rotated 90 degrees clockwise, (x, y) -> (y, -x); facings turn with it,
    # so relative positions are preserved.
    rotated_map = {
        "objects": [
            entry("object1", (2, -1), "right"),
            entry("object2", (4, -3), "left"),
            entry("object3", (6, -5)),
        ],
        "views": [entry("view1", (2, -1), "right"), entry("view2", (4, -3), "left")],
    }

    # (label, generated map, ground-truth map)
    comparisons = (
        ("Complex vs Complex (different names):", complex_map, renamed_map),
        ("\nSimple vs Complex:", simple_map, complex_map),
        ("\nComplex vs Simple:", complex_map, simple_map),
    )
    for label, generated, grounded in comparisons:
        print(label)
        print(calculate_cogmap_similarity(generated, grounded))

    # Rotation invariance: the rotated map should align with the original.
    print("\nRotated map vs Original map (rotation invariance test):")
    outcome = calculate_cogmap_similarity(rotated_map, complex_map)
    print(outcome)
    chosen = outcome["best_rotation"]
    if chosen:
        print(f"Best rotation found: {chosen['name']}")


# Run the printed smoke test when this file is executed directly. The module
# uses relative imports (``from .graph_operations import ...``), so it needs to
# be launched as part of its package, e.g. ``python -m <package>.<module>``.
if __name__ == "__main__":
    test_calculate_cogmap_similarity()
